# WDS-QA-Bot / app.py
# (Hugging Face Space file-view header preserved as a comment:
#  author jeongsk — "Update app.py" — commit d8cad3f, verified)
import os
import pickle
import streamlit as st
from dotenv import load_dotenv
from laas import ChatLaaS
from langchain.embeddings import CacheBackedEmbeddings
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import (
CrossEncoderReranker,
FlashrankRerank,
)
from langchain.storage import LocalFileStore
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
LanguageParser,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
# Pull local configuration (e.g. API keys) from a .env file into the process
# environment before anything else reads os.environ.
load_dotenv()
# Enable LangSmith/LangChain tracing and tag traces with this project name.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Code QA Bot"
@st.cache_resource
def setup_embeddings_and_db(project_folder: str):
    """Load the pre-built FAISS vector store for *project_folder*.

    Builds a CPU-hosted BGE-M3 embedding model whose vectors are cached on
    disk (repeated queries over identical texts are not re-embedded), then
    loads the FAISS index stored under ``<project_folder>/langchain_faiss``.
    Cached by Streamlit so the model/index load happens once per session.

    Args:
        project_folder: Directory containing the ``langchain_faiss`` index.

    Returns:
        The loaded FAISS vector store wired to the cached embeddings.
    """
    cache_root_path = os.path.join(os.path.expanduser("~"), ".cache")
    cache_models_path = os.path.join(cache_root_path, "models")
    cache_embeddings_path = os.path.join(cache_root_path, "embeddings")
    # exist_ok avoids the check-then-create race of the exists()/makedirs pair.
    os.makedirs(cache_models_path, exist_ok=True)
    os.makedirs(cache_embeddings_path, exist_ok=True)

    store = LocalFileStore(cache_embeddings_path)

    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
        cache_folder=cache_models_path,
        multi_process=False,
        show_progress=True,
    )
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        # Namespacing by model name keeps caches of different models separate.
        namespace=embeddings.model_name,
    )

    faiss_index_dir = os.path.join(project_folder, "langchain_faiss")
    # NOTE(security): allow_dangerous_deserialization unpickles the stored
    # index. Acceptable only because the index ships with this app (trusted);
    # never enable it for indexes from untrusted sources.
    return FAISS.load_local(
        faiss_index_dir,
        cached_embeddings,
        allow_dangerous_deserialization=True,
    )
# Build the retrieval + reranking + generation pipeline.
@st.cache_resource
def setup_retrievers_and_chain(
    _db: VectorStore, project_folder: str
):  # leading underscore on _db keeps Streamlit from hashing the vector store
    """Assemble the RAG chain: hybrid retrieval -> cross-encoder rerank -> LaaS.

    Merges a pickled BM25 (sparse) retriever with an MMR-based FAISS (dense)
    retriever, reranks the merged candidates down to the top 5 with a
    cross-encoder, and feeds the stringified documents plus the question into
    the LaaS chat completion. Returns a runnable that maps a question string
    to the model's answer string.
    """
    # Dense retriever over the FAISS store (maximal marginal relevance).
    dense_retriever = _db.as_retriever(search_type="mmr", search_kwargs={"k": 20})

    # Sparse keyword retriever, pre-built and pickled next to the index.
    with open(os.path.join(project_folder, "bm25_retriever.pkl"), "rb") as fh:
        sparse_retriever = pickle.load(fh)  # trusted local artifact
    sparse_retriever.k = 20

    hybrid_retriever = EnsembleRetriever(
        retrievers=[sparse_retriever, dense_retriever],
        weights=[0.6, 0.4],
        search_type="mmr",
    )

    reranker = CrossEncoderReranker(
        model=HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3"),
        top_n=5,
    )
    reranking_retriever = ContextualCompressionRetriever(
        base_compressor=reranker,
        base_retriever=hybrid_retriever,
    )

    laas = ChatLaaS(
        project=st.secrets["LAAS_PROJECT"],
        api_key=st.secrets["LAAS_API_KEY"],
        hash=st.secrets["LAAS_HASH"],
    )

    def _call_laas(inputs):
        # The prompt template lives server-side (selected by `hash`); we only
        # supply its parameters, hence the empty prompt string.
        return laas.invoke(
            "", params={"context": inputs["context"], "question": inputs["question"]}
        )

    return (
        {
            "context": reranking_retriever | RunnableLambda(lambda docs: str(docs)),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(_call_laas)
        | StrOutputParser()
    )
def sidebar_content():
    """Render the sidebar: usage guide plus an answer-reset button."""
    st.sidebar.title("์‚ฌ์šฉ ๊ฐ€์ด๋“œ")
    st.sidebar.info(
        """
    1. ์™ผ์ชฝ ํ…์ŠคํŠธ ์˜์—ญ์— ์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”.
    2. '๋‹ต๋ณ€ ์ƒ์„ฑ' ๋ฒ„ํŠผ์„ ํด๋ฆญํ•˜์„ธ์š”.
    3. ๋‹ต๋ณ€์ด ์•„๋ž˜์— ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.
    4. ์ƒˆ๋กœ์šด ์งˆ๋ฌธ์„ ํ•˜๋ ค๋ฉด '๋‹ต๋ณ€ ์ดˆ๊ธฐํ™”' ๋ฒ„ํŠผ์„ ์‚ฌ์šฉํ•˜์„ธ์š”.
    """
    )
    if st.sidebar.button("๋‹ต๋ณ€ ์ดˆ๊ธฐํ™”", key="reset"):
        st.session_state.answer = ""
        # st.experimental_rerun() was deprecated and later removed;
        # st.rerun() (Streamlit >= 1.27) is the supported replacement.
        st.rerun()
def main():
    """App entry point: page config, sidebar, question form, answer display."""
    st.set_page_config(page_title="WDS QA ๋ด‡", page_icon="๐Ÿค–", layout="wide")
    sidebar_content()

    st.title("๐Ÿค– WDS QA ๋ด‡")
    st.subheader("์งˆ๋ฌธํ•˜๊ธฐ")
    user_question = st.text_area("์ฝ”๋“œ์— ๋Œ€ํ•ด ๊ถ๊ธˆํ•œ ์ ์„ ๋ฌผ์–ด๋ณด์„ธ์š”:", height=100)

    if st.button("๋‹ต๋ณ€ ์ƒ์„ฑ", key="generate"):
        if not user_question:
            st.warning("์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.")
        else:
            with st.spinner("๋‹ต๋ณ€์„ ์ƒ์„ฑ ์ค‘์ž…๋‹ˆ๋‹ค..."):
                project_folder = "wds"
                # Both setup calls are @st.cache_resource, so the heavy
                # model/index loading only happens on the first question.
                vector_db = setup_embeddings_and_db(project_folder)
                rag_chain = setup_retrievers_and_chain(vector_db, project_folder)
                st.session_state.answer = rag_chain.invoke(user_question)

    # Show the last answer, if any (survives reruns via session state).
    if st.session_state.get("answer"):
        st.subheader("๋‹ต๋ณ€")
        st.markdown(st.session_state.answer)

    st.markdown("---")
    st.caption("ยฉ 2023 WDS QA ๋ด‡. ๋ชจ๋“  ๊ถŒ๋ฆฌ ๋ณด์œ .")


if __name__ == "__main__":
    main()