File size: 5,580 Bytes
e7055d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7c7d7c
e7055d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8cad3f
e7055d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7c7d7c
 
 
 
 
 
 
 
 
 
e7055d3
d7c7d7c
e7055d3
d7c7d7c
e7055d3
 
d7c7d7c
 
 
 
e7055d3
d7c7d7c
e7055d3
d7c7d7c
 
e7055d3
d7c7d7c
e7055d3
d7c7d7c
 
 
 
e7055d3
 
d7c7d7c
 
 
 
 
 
e7055d3
d7c7d7c
 
e7055d3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
import pickle

import streamlit as st
from dotenv import load_dotenv
from laas import ChatLaaS
from langchain.embeddings import CacheBackedEmbeddings
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import (
    CrossEncoderReranker,
    FlashrankRerank,
)
from langchain.storage import LocalFileStore
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
    LanguageParser,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

# Load variables from a local .env file into the process environment.
load_dotenv()

# Enable LangSmith tracing and tag runs under a named project.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Code QA Bot"


@st.cache_resource
def setup_embeddings_and_db(project_folder: str):
    """Build a cache-backed embedding model and load the project's FAISS index.

    Args:
        project_folder: Directory containing the ``langchain_faiss`` index.

    Returns:
        The FAISS vector store loaded with cache-backed embeddings.
    """
    cache_root_path = os.path.join(os.path.expanduser("~"), ".cache")
    cache_models_path = os.path.join(cache_root_path, "models")
    cache_embeddings_path = os.path.join(cache_root_path, "embeddings")

    # exist_ok avoids the race between the existence check and creation.
    os.makedirs(cache_models_path, exist_ok=True)
    os.makedirs(cache_embeddings_path, exist_ok=True)

    store = LocalFileStore(cache_embeddings_path)

    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
        cache_folder=cache_models_path,
        multi_process=False,
        show_progress=True,
    )

    # Wrap with a byte-store cache so repeated texts are embedded only once.
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace=embeddings.model_name,
    )

    faiss_db_index = os.path.join(project_folder, "langchain_faiss")
    # NOTE(review): allow_dangerous_deserialization unpickles the index file;
    # acceptable only because the index is a locally produced artifact.
    db = FAISS.load_local(
        faiss_db_index,  # directory of the FAISS index to load
        cached_embeddings,  # embedding function used for queries
        allow_dangerous_deserialization=True,
    )

    return db


@st.cache_resource
def setup_retrievers_and_chain(
    _db: VectorStore, project_folder: str
):
    """Assemble the ensemble retriever, cross-encoder reranker, and RAG chain.

    Args:
        _db: Vector store for semantic retrieval (the leading underscore
            tells Streamlit's cache not to hash this argument).
        project_folder: Directory containing ``bm25_retriever.pkl``.

    Returns:
        A runnable chain mapping a question string to an answer string.
    """
    faiss_retriever = _db.as_retriever(search_type="mmr", search_kwargs={"k": 20})

    # SECURITY: pickle.load executes arbitrary code on untrusted data;
    # only load this file if it is a trusted, locally generated artifact.
    bm25_retriever_path = os.path.join(project_folder, "bm25_retriever.pkl")
    with open(bm25_retriever_path, "rb") as f:
        bm25_retriever = pickle.load(f)
        bm25_retriever.k = 20  # match the FAISS retriever's candidate count

    # Merge lexical (BM25) and semantic (FAISS) results, weighted 60/40.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.6, 0.4],
        search_type="mmr",  # NOTE(review): verify EnsembleRetriever accepts this kwarg
    )

    # Rerank the merged candidates down to the top 5 with a cross-encoder.
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compressor = CrossEncoderReranker(model=model, top_n=5)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )

    laas = ChatLaaS(
        project=st.secrets["LAAS_PROJECT"],
        api_key=st.secrets["LAAS_API_KEY"],
        hash=st.secrets["LAAS_HASH"],
    )

    # Chain: retrieve+rerank -> stringify docs -> call LaaS -> parse to str.
    rag_chain = (
        {
            "context": compression_retriever | RunnableLambda(str),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(
            lambda x: laas.invoke(
                "", params={"context": x["context"], "question": x["question"]}
            )
        )
        | StrOutputParser()
    )

    return rag_chain


def sidebar_content():
    """Render the sidebar usage guide and the answer-reset button."""
    st.sidebar.title("์‚ฌ์šฉ ๊ฐ€์ด๋“œ")
    st.sidebar.info(
        """
    1. ์™ผ์ชฝ ํ…์ŠคํŠธ ์˜์—ญ์— ์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”.
    2. '๋‹ต๋ณ€ ์ƒ์„ฑ' ๋ฒ„ํŠผ์„ ํด๋ฆญํ•˜์„ธ์š”.
    3. ๋‹ต๋ณ€์ด ์•„๋ž˜์— ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.
    4. ์ƒˆ๋กœ์šด ์งˆ๋ฌธ์„ ํ•˜๋ ค๋ฉด '๋‹ต๋ณ€ ์ดˆ๊ธฐํ™”' ๋ฒ„ํŠผ์„ ์‚ฌ์šฉํ•˜์„ธ์š”.
    """
    )

    if st.sidebar.button("๋‹ต๋ณ€ ์ดˆ๊ธฐํ™”", key="reset"):
        st.session_state.answer = ""
        # st.experimental_rerun() is deprecated and removed in recent
        # Streamlit releases; st.rerun() is the supported replacement.
        st.rerun()


def main():
    """Render the QA page: question input, answer generation, and display."""
    st.set_page_config(page_title="WDS QA ๋ด‡", page_icon="๐Ÿค–", layout="wide")

    sidebar_content()

    st.title("๐Ÿค– WDS QA ๋ด‡")

    st.subheader("์งˆ๋ฌธํ•˜๊ธฐ")
    user_question = st.text_area("์ฝ”๋“œ์— ๋Œ€ํ•ด ๊ถ๊ธˆํ•œ ์ ์„ ๋ฌผ์–ด๋ณด์„ธ์š”:", height=100)

    if st.button("๋‹ต๋ณ€ ์ƒ์„ฑ", key="generate"):
        # strip() so a whitespace-only question is rejected like an empty one.
        if user_question and user_question.strip():
            with st.spinner("๋‹ต๋ณ€์„ ์ƒ์„ฑ ์ค‘์ž…๋‹ˆ๋‹ค..."):
                project_folder = "wds"
                db = setup_embeddings_and_db(project_folder)
                rag_chain = setup_retrievers_and_chain(db, project_folder)
                st.session_state.answer = rag_chain.invoke(user_question)
        else:
            st.warning("์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.")

    # Show the last answer if one has been generated (and not reset).
    if st.session_state.get("answer"):
        st.subheader("๋‹ต๋ณ€")
        st.markdown(st.session_state.answer)

    st.markdown("---")
    st.caption("ยฉ 2023 WDS QA ๋ด‡. ๋ชจ๋“  ๊ถŒ๋ฆฌ ๋ณด์œ .")


# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()