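"""Streamlit app for question answering over Moroccan law (currently 136 laws).

Retrieval uses a persisted LlamaIndex vector index built with the
intfloat/multilingual-e5-large embedding model (queried through the Hugging
Face Inference API); answers are generated in Arabic by Mistral's
mistral-large-latest model.
"""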
import os

import streamlit as st
from llama_index.core import (
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceInferenceAPIEmbedding
from llama_index.llms.mistralai import MistralAI

# The Mistral API key is read from the environment.
mistral_api_key = os.getenv("mistral_api_key")

class QASystem:
    def __init__(self,
                 mistral_api_key: str = mistral_api_key,
                 data_dir: str = "./data",
                 storage_dir: str = "./index_llama_136_multilingual-e5-large",
                 model_temperature: float = 0.002):
        self.data_dir = data_dir
        self.storage_dir = storage_dir
        # Embedding model served through the Hugging Face Inference API.
        # (Alternative: run it locally with
        # HuggingFaceEmbedding(model_name="intfloat/multilingual-e5-large", trust_remote_code=True).)
        self.embedding_model = HuggingFaceInferenceAPIEmbedding(
            model_name="intfloat/multilingual-e5-large",
        )
        self.llm = MistralAI(
            model="mistral-large-latest",
            api_key=mistral_api_key,
            temperature=model_temperature,
            max_tokens=1024,
        )
        self._configure_settings()
        # self.create_index()  # one-off: build and persist the index (see note below)
        self.index = self.load_index()

    def _configure_settings(self):
        Settings.llm = self.llm
        Settings.embed_model = self.embedding_model

    def create_index(self):
        """Chunk the documents in data_dir, embed them and persist the index to storage_dir."""
        print("creating index")
        documents = SimpleDirectoryReader(self.data_dir).load_data()
        node_parser = SentenceSplitter(chunk_size=206, chunk_overlap=0)
        nodes = node_parser.get_nodes_from_documents(documents, show_progress=True)
        sentence_index = VectorStoreIndex(nodes, show_progress=True)
        sentence_index.storage_context.persist(self.storage_dir)
        return sentence_index

    def load_index(self):
        """Load the persisted vector index from storage_dir."""
        storage_context = StorageContext.from_defaults(persist_dir=self.storage_dir)
        return load_index_from_storage(storage_context, embed_model=self.embedding_model)

    def create_query_engine(self):
        # Arabic QA prompt: "Use the following information to answer the question
        # at the end. If you do not know the answer, just say that you do not know;
        # do not try to make up an answer. ... Question: ... Answer in Arabic:"
        template = """
        استخدم المعلومات التالية للإجابة على السؤال في النهاية. إذا لم تكن تعرف الإجابة، فقل فقط أنك لا تعرف، لا تحاول اختلاق إجابة.
        {context_str}
        السؤال: {query_str}
        الإجابة بالعربية:
        """
        prompt = PromptTemplate(template=template)
        return self.index.as_query_engine(
            similarity_top_k=10,
            streaming=True,
            text_qa_template=prompt,
            response_mode="tree_summarize",  # alternatives: simple_summarize, compact
            embed_model=self.embedding_model,
        )

    def query(self, question: str):
        query_engine = self.create_query_engine()
        return query_engine.query(question)
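
# Note on the one-off index build: __init__ only loads an already persisted
# index from `storage_dir`, so the vector store is assumed to have been built
# beforehand. The commented-out `self.create_index()` call in __init__ is the
# hook for that: enable it once so the index is chunked from ./data and
# persisted before load_index() runs, then comment it out again for serving.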

# Cache the QA system as a singleton to avoid reloading the index and models
# on every Streamlit rerun.
@st.cache_resource
def get_qa_system():
    return QASystem()

def main():
    st.markdown("""
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Noto+Kufi+Arabic:wght@100;200;300;400;500;600;700;800;900&display=swap');
    /* Apply the font globally */
    * {
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    body {
        text-align: right;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    /* Style for all text elements */
    p, div, span, button, input, label, h1, h2, h3, h4, h5, h6 {
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    /* Main title with reduced size */
    h1 {
        font-size: 1.2em !important;
        margin-bottom: 0.5em !important;
        text-align: center;
        padding: 0.3px;
    }
    .css-1h9b9rq.e1tzin5v0 {
        direction: rtl;
        text-align: right;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    /* Expander styling */
    .streamlit-expanderContent, div[data-testid="stExpander"] {
        direction: rtl !important;
        text-align: right !important;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    /* Expander buttons */
    button[kind="secondary"] {
        direction: rtl !important;
        text-align: right !important;
        width: 100% !important;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
        font-weight: 30 !important;
    }
    /* All text elements */
    p, div {
        direction: rtl !important;
        text-align: right !important;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    /* Bullet points */
    ul, li {
        direction: rtl !important;
        text-align: right !important;
        margin-right: 20px !important;
        margin-left: 0 !important;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    .stTextInput, .stButton {
        margin-left: auto;
        margin-right: 0;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    .stTextInput {
        width: 100% !important;
        direction: rtl;
        text-align: right;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    /* Force RTL on all containers */
    .element-container, .stMarkdown {
        direction: rtl !important;
        text-align: right !important;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    /* Sources expander (generated Streamlit class names) */
    .css-1fcdlhc, .css-1629p8f {
        direction: rtl !important;
        text-align: right !important;
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    /* Title */
    .stTitle {
        font-family: 'Noto Kufi Arabic', sans-serif !important;
        font-weight: 700 !important;
    }
    /* Buttons */
    .stButton>button {
        font-family: 'Noto Kufi Arabic', sans-serif !important;
        font-weight: 500 !important;
    }
    /* Text inputs */
    .stTextInput>div>div>input {
        font-family: 'Noto Kufi Arabic', sans-serif !important;
    }
    </style>
    """, unsafe_allow_html=True)
st.title("هذا تطبيق للاجابة عن الاسئلة المتعلقة بالقانون المغربي ") | |
st.title("حاليا يضم 136 قانونا") | |
qa_system = get_qa_system() | |
question = st.text_input("اطرح سؤالك :",placeholder=None) | |
if st.button("بحث"): | |
if question: | |
response_container = st.empty() | |
def stream_response(token): | |
if 'current_response' not in st.session_state: | |
st.session_state.current_response = "" | |
st.session_state.current_response += token | |
response_container.markdown(st.session_state.current_response, unsafe_allow_html=True) | |
try: | |
query_engine = qa_system.create_query_engine() | |
st.session_state.current_response = "" | |
response = query_engine.query(question) | |
full_response = "" | |
for token in response.response_gen: | |
full_response += token | |
stream_response(token) | |
                if hasattr(response, 'source_nodes'):
                    # "المصادر" = "Sources" header, rendered right-to-left.
                    st.markdown("""
                    <div style="direction: rtl !important; text-align: right !important; font-family: 'Noto Kufi Arabic', sans-serif !important;">
                        <div class="streamlit-expanderHeader">
                            المصادر
                        </div>
                    </div>
                    """, unsafe_allow_html=True)
                    with st.expander(""):
                        for node in response.source_nodes:
                            # "مصادر الجواب" = "Sources of the answer" (source file name), followed by the excerpt.
                            st.markdown(f"""
                            <div style="direction: rtl !important; text-align: right !important; font-family: 'Noto Kufi Arabic', sans-serif !important;">
                                <p style="text-align: right !important;">مصادر الجواب : {node.metadata.get('file_name', 'Unknown')}</p>
                                <p style="text-align: right !important;">Excerpt: {node.text}</p>
                            </div>
                            """, unsafe_allow_html=True)
            except Exception as e:
                st.error(f"An error occurred: {e}")
if __name__ == "__main__": | |
main() |
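
# Local run (sketch): `streamlit run app.py`, assuming this module is saved as
# app.py and the environment provides `mistral_api_key` plus, if required, a
# Hugging Face token for the Inference API embeddings.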