import logging
import os

import streamlit as st
import yaml
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

from src.data_preparation import split_documents
from src.generator import answer_with_rag
from src.retriever import init_vectorDB_from_doc

logger = logging.getLogger(__name__)

# Load environment variables (notably NVIDIA_API_KEY, used below) from .env.
load_dotenv()
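# Illustrative .env content (the key name is what this script reads; the
# value shown is only a placeholder):
#   NVIDIA_API_KEY=nvapi-xxxxxxxxxxxxxxxx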
|
def load_config():
    """Read runtime settings from config.yml at the project root."""
    with open("./config.yml", "r") as file_object:
        try:
            cfg = yaml.safe_load(file_object)
        except yaml.YAMLError as exc:
            logger.error(str(exc))
            raise
        else:
            return cfg


cfg = load_config()

EMBEDDING_MODEL_NAME = cfg['EMBEDDING_MODEL_NAME']
DATA_FILE_PATH = cfg['DATA_FILE_PATH']
READER_MODEL_NAME = cfg['READER_MODEL_NAME']
RERANKER_MODEL_NAME = cfg['RERANKER_MODEL_NAME']
VECTORDB_PATH = cfg['VECTORDB_PATH']
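# Illustrative config.yml (the keys are exactly those read above; the values
# are placeholders, not the project's actual settings):
#
#   EMBEDDING_MODEL_NAME: "NV-Embed-QA"
#   DATA_FILE_PATH: "./data/college_pediatrie_2024.pdf"
#   READER_MODEL_NAME: "meta/llama3-70b-instruct"
#   RERANKER_MODEL_NAME: "colbert-ir/colbertv2.0"
#   VECTORDB_PATH: "./vectordb"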
|
|
|
|
|
def main():
    st.title("A RAG to query the Collège de Pédiatrie 2024")
    user_query = st.text_input("Enter your question:")

    # Build the knowledge base once and cache it in session_state so it is
    # not recomputed on every Streamlit rerun.
    if "KNOWLEDGE_VECTOR_DATABASE" not in st.session_state:
        st.session_state.loader = PyPDFLoader(DATA_FILE_PATH)
        st.session_state.raw_document_base = st.session_state.loader.load()
        # Split points tried in order, from coarse (markdown headings,
        # horizontal rules) down to fine (words, characters).
        st.session_state.MARKDOWN_SEPARATORS = [
            "\n#{1,6} ",
            "```\n",
            "\n\\*\\*\\*+\n",
            "\n---+\n",
            "\n___+\n",
            "\n\n",
            "\n",
            " ",
            "",
        ]
        st.session_state.docs_processed = split_documents(
            400,  # chunk size passed to split_documents
            st.session_state.raw_document_base,
            separator=st.session_state.MARKDOWN_SEPARATORS,
        )
        st.session_state.embedding_model = NVIDIAEmbeddings(model="NV-Embed-QA", truncate="END")
        st.session_state.KNOWLEDGE_VECTOR_DATABASE = init_vectorDB_from_doc(
            st.session_state.docs_processed,
            st.session_state.embedding_model,
        )
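    # init_vectorDB_from_doc (src/retriever.py) is expected to return a
    # LangChain vector store (e.g. FAISS) exposing as_retriever(), used below.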
|
    if user_query and st.button("Get Answer"):
        # Retrieve the top-k chunks by embedding similarity.
        num_doc_before_rerank = 5
        st.session_state.retriever = st.session_state.KNOWLEDGE_VECTOR_DATABASE.as_retriever(
            search_type="similarity",
            search_kwargs={"k": num_doc_before_rerank},
        )

        st.write("### Please wait while the answer is being generated...")
        llm = ChatNVIDIA(
            model=READER_MODEL_NAME,
            api_key=os.getenv("NVIDIA_API_KEY"),
            temperature=0.2,
            top_p=0.7,
            max_tokens=1024,
        )
        answer, relevant_docs = answer_with_rag(
            query=user_query, llm=llm, retriever=st.session_state.retriever
        )
        st.write("### Answer:")
        st.write(answer)

        st.write("### Relevant Documents:")
        for i, doc in enumerate(relevant_docs, start=1):
            st.write(f"Document {i}:\n{doc}")
|
|
|
|
|
if __name__ == "__main__":
    main()
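# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py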