# NOTE(review): removed non-code scrape residue ("Spaces: / Sleeping / Sleeping")
# — Hugging Face Spaces page chrome captured along with the source file.
# Standard library
import os
import pickle

# Third-party
import streamlit as st
from dotenv import load_dotenv
from laas import ChatLaaS
from langchain.embeddings import CacheBackedEmbeddings
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import (
    CrossEncoderReranker,
    FlashrankRerank,
)
from langchain.storage import LocalFileStore
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
    LanguageParser,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

# Load environment variables (API keys, etc.) from a local .env file.
load_dotenv()

# Enable LangSmith tracing and tag runs under a dedicated project name.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Code QA Bot"
def setup_embeddings_and_db(project_folder: str):
    """Create cached HuggingFace embeddings and load the project's FAISS index.

    Args:
        project_folder: Directory containing the pre-built ``langchain_faiss``
            index for the project being queried.

    Returns:
        A ``FAISS`` vector store backed by disk-cached BGE-M3 embeddings.
    """
    # Cache locations under ~/.cache: model weights and computed embeddings.
    cache_root_path = os.path.join(os.path.expanduser("~"), ".cache")
    cache_models_path = os.path.join(cache_root_path, "models")
    cache_embeddings_path = os.path.join(cache_root_path, "embeddings")
    # exist_ok=True avoids the check-then-create race of os.path.exists().
    os.makedirs(cache_models_path, exist_ok=True)
    os.makedirs(cache_embeddings_path, exist_ok=True)

    # Byte store used to memoize embedding vectors across runs.
    store = LocalFileStore(cache_embeddings_path)

    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
        cache_folder=cache_models_path,
        multi_process=False,
        show_progress=True,
    )
    # Wrap the embedder so repeated texts hit the on-disk cache; the model
    # name namespaces entries so switching models never mixes vectors.
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace=embeddings.model_name,
    )

    # Load the pre-built FAISS index for this project.
    # allow_dangerous_deserialization is required because FAISS metadata is
    # pickled — only safe because the index is produced locally, not by users.
    faiss_db_index = os.path.join(project_folder, "langchain_faiss")
    db = FAISS.load_local(
        faiss_db_index,
        cached_embeddings,
        allow_dangerous_deserialization=True,
    )
    return db
def setup_retrievers_and_chain(_db: VectorStore, project_folder: str):
    """Assemble the ensemble retriever, reranker, and RAG answer chain.

    Args:
        _db: FAISS vector store produced by ``setup_embeddings_and_db``
            (leading underscore kept for Streamlit-cache hashing conventions).
        project_folder: Directory holding the pickled BM25 retriever.

    Returns:
        A runnable chain mapping a question string to an answer string.
    """
    # Dense retriever: MMR search over the FAISS index, top 20 candidates.
    faiss_retriever = _db.as_retriever(search_type="mmr", search_kwargs={"k": 20})

    # Sparse retriever: restore the pre-fit BM25 index from disk.
    # SECURITY NOTE: pickle.load executes arbitrary code from the file —
    # acceptable only because this artifact is produced by our own build.
    bm25_retriever_path = os.path.join(project_folder, "bm25_retriever.pkl")
    with open(bm25_retriever_path, "rb") as f:
        bm25_retriever = pickle.load(f)
    bm25_retriever.k = 20

    # Hybrid retrieval: BM25 weighted 0.6, dense FAISS 0.4.
    # NOTE(review): `search_type` here applies to the ensemble wrapper, not
    # the underlying retrievers — confirm it has any effect in this version.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.6, 0.4],
        search_type="mmr",
    )

    # Rerank the hybrid candidates with a cross-encoder, keeping the top 5.
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compressor = CrossEncoderReranker(model=model, top_n=5)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )

    # LaaS-hosted prompt: credentials come from Streamlit secrets.
    laas = ChatLaaS(
        project=st.secrets["LAAS_PROJECT"],
        api_key=st.secrets["LAAS_API_KEY"],
        hash=st.secrets["LAAS_HASH"],
    )

    # question -> {context, question} -> LaaS prompt -> answer string.
    rag_chain = (
        {
            "context": compression_retriever | RunnableLambda(str),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(
            lambda x: laas.invoke(
                "", params={"context": x["context"], "question": x["question"]}
            )
        )
        | StrOutputParser()
    )
    return rag_chain
def sidebar_content():
    """Render the sidebar usage guide and the answer-reset button."""
    st.sidebar.title("์ฌ์ฉ ๊ฐ์ด๋")
    st.sidebar.info(
        """
    1. ์ผ์ชฝ ํ ์คํธ ์์ญ์ ์ง๋ฌธ์ ์ ๋ ฅํ์ธ์.
    2. '๋ต๋ณ ์์ฑ' ๋ฒํผ์ ํด๋ฆญํ์ธ์.
    3. ๋ต๋ณ์ด ์๋์ ํ์๋ฉ๋๋ค.
    4. ์๋ก์ด ์ง๋ฌธ์ ํ๋ ค๋ฉด '๋ต๋ณ ์ด๊ธฐํ' ๋ฒํผ์ ์ฌ์ฉํ์ธ์.
    """
    )
    # Clear the stored answer and redraw the page from scratch.
    if st.sidebar.button("๋ต๋ณ ์ด๊ธฐํ", key="reset"):
        st.session_state.answer = ""
        # TODO(review): st.experimental_rerun is deprecated/removed in newer
        # Streamlit releases — migrate to st.rerun() once the pinned version
        # is confirmed to support it.
        st.experimental_rerun()
def main():
    """Streamlit entry point: question form, RAG invocation, answer display."""
    st.set_page_config(page_title="WDS QA ๋ด", page_icon="๐ค", layout="wide")
    sidebar_content()

    st.title("๐ค WDS QA ๋ด")
    st.subheader("์ง๋ฌธํ๊ธฐ")
    user_question = st.text_area("์ฝ๋์ ๋ํด ๊ถ๊ธํ ์ ์ ๋ฌผ์ด๋ณด์ธ์:", height=100)

    if st.button("๋ต๋ณ ์์ฑ", key="generate"):
        if user_question:
            with st.spinner("๋ต๋ณ์ ์์ฑ ์ค์ ๋๋ค..."):
                # NOTE(review): the DB and chain are rebuilt on every click —
                # consider st.cache_resource on the setup functions (their
                # underscore-prefixed params suggest that was the intent).
                project_folder = "wds"
                db = setup_embeddings_and_db(project_folder)
                rag_chain = setup_retrievers_and_chain(db, project_folder)
                response = rag_chain.invoke(user_question)
                st.session_state.answer = response
        else:
            st.warning("์ง๋ฌธ์ ์ ๋ ฅํด์ฃผ์ธ์.")

    # Persisted answer survives reruns until reset via the sidebar button.
    if "answer" in st.session_state and st.session_state.answer:
        st.subheader("๋ต๋ณ")
        st.markdown(st.session_state.answer)

    st.markdown("---")
    st.caption("ยฉ 2023 WDS QA ๋ด. ๋ชจ๋  ๊ถ๋ฆฌ ๋ณด์ .")


if __name__ == "__main__":
    main()