File size: 10,570 Bytes
afe04d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import os
import time
import random
import torch
import dotenv
import ollama
import logging
import requests
import streamlit as st

from typing import Optional, List
dotenv.load_dotenv()
logger = logging.getLogger(__name__)

# Default prompt context (unchanged)
DEFAULT_PROMPT_CONTEXT = """Bạn là một trợ lí AI pháp luật Việt Nam, có kiến thức về pháp luật Việt Nam.
Dựa vào ngữ cảnh hoặc tài liệu sau hãy trả lời câu hỏi người dùng
Ngữ cảnh: {context}
Câu hỏi: {question}
"""

class Config:
   # URL_RETRIEVE = str(os.getenv("URL_RETRIEVE", "http://202.191.56.254:9002/getPredictionOutput"))
    URL_RETRIEVE = str(os.getenv("URL_RETRIEVE", "http://0.0.0.0:9002/getPredictionOutput"))
    TOP_K = int(os.getenv("TOP_K", 5))


def prompt_model(top_k_chunks: Optional[List]):
    res = []
    for context in top_k_chunks:
        text = ''
        if context.get("diem_id", ""):
            text += "điểm " + context.get("diem_id", "") + " "
            
        if context.get("khoan_id", ""):
            text += "khoản " + str(int(context.get("khoan_id"))) + " "
        
        if context.get("diem_id", ""):
            text += context.get("diem_id", "") + " "
        
        if context.get("law_id", ""):
            if 'ttlt' in context.get("law_id", ""):
                text += "thông tư liên tịch " + context.get("law_id", "").upper() + " "
            elif 'tt' in context.get("law_id", ""):
                text += "thông tư " + context.get("law_id", "").upper() + " "
            elif 'nđ' in context.get("law_id", ""):
                text += "nghị định " + context.get("law_id", "").upper() + " "
            elif 'tb' in context.get("law_id", ""):
                text += "thông báo " + context.get("law_id", "").upper() + " "
            else:
                text += "luật số " + context.get("law_id", "").upper() + " "
            
        if context.get("title", "") and random.choice([1, 0]):
            text += context.get("title", "") + " "

        text += context.get("text", "") + " "
        res.append(text)
    return res[:Config.TOP_K]


def get_retrieval(query, config):
    retrieve = {"query": [query]}
    response = requests.post(config.URL_RETRIEVE, json=retrieve)
    if response.status_code == 200:
        return response.json()['predict'][0][0][0]['top_relevant_chunks']
    else:
        return []


def model_res_generator(prompt_template, retrieval_results, question):
    context = '\n'.join(prompt_model(retrieval_results))  
    input_model = prompt_template.format(context=context, question=question)
    
    stream = ollama.chat(
        model=st.session_state["model"],
        messages=[{'role': 'user', 'content': input_model}],
        options={
            'temperature': 0.0
        },
        stream=True,
    )

    full_response = ""
    for chunk in stream:
        chunk_content = chunk.get("message", {}).get("content", "")
        if chunk_content:
            full_response += chunk_content
            yield full_response


def process_input(prompt_template, user_input, config):
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        retrieval_results = get_retrieval(user_input, config)
        retrieval_results.sort(key=lambda x: x['bi_score'], reverse=True)
        st.session_state.retrieval_results = retrieval_results[:st.session_state.top_k]
        st.session_state.queries_and_results[user_input] = st.session_state.retrieval_results
        st.session_state.selected_query = user_input
        st.session_state.query_list = list(st.session_state.queries_and_results.keys())
        st.session_state.messages.append({"role": "assistant", "content": ""})

        for message in st.session_state.messages[-2:-1]:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            for chunk in model_res_generator(prompt_template, retrieval_results, user_input):
                message_placeholder.markdown(chunk + "▌")
                st.session_state.messages[-1]["content"] = chunk
            message_placeholder.markdown(st.session_state.messages[-1]["content"])

        st.rerun() 


def reset_session_state(config):
    st.session_state.messages = [] 
    st.session_state.retrieval_results = []
    st.session_state.queries_and_results = {}
    st.session_state.selected_query = None
    st.session_state.top_k = config.TOP_K
    st.session_state.query_list = []  


def update_all_queries_results(config):
    """Cập nhật lại kết quả của tất cả các truy vấn khi top_k thay đổi"""
    for query in st.session_state.queries_and_results.keys():
        retrieval_results = get_retrieval(query, config)
        retrieval_results.sort(key=lambda x: x['bi_score'], reverse=True)
        st.session_state.queries_and_results[query] = retrieval_results[:st.session_state.top_k]
    st.rerun()



if __name__ == "__main__":
    st.set_page_config(page_title="AsklexAI", page_icon="🧊", layout="centered")
    config = Config()
    models = [model["name"] for model in ollama.list()["models"]]

    # Initialize session state variables if not already initialized
    if 'query_list' not in st.session_state:
        st.session_state.query_list = []  
    if 'messages' not in st.session_state:
        st.session_state.messages = [] 
    if 'queries_and_results' not in st.session_state:
        st.session_state.queries_and_results = {}
    if 'model' not in st.session_state:
        st.session_state.model = "" 
    if 'retrieval_results' not in st.session_state:
        st.session_state.retrieval_results = []
    if 'selected_query' not in st.session_state:
        st.session_state.selected_query = None
    if 'top_k' not in st.session_state:
        st.session_state.top_k = 5
    if 'custom_prompt' not in st.session_state:
        st.session_state.custom_prompt = DEFAULT_PROMPT_CONTEXT  # Set default custom prompt if not yet set

    st.markdown("""
    <style>
    .css-1aumxhk {
        background-color: #F0F2F6;  /* Light blue-gray background */
    }
    .css-1aumxhk .stMarkdown {
        color: #333;  /* Darker text for better readability */
    }
    .css-1aumxhk .stButton>button {
        background-color: #4A90E2;  /* Blue button color */
        color: white;
    }
    .css-1aumxhk .stSelectbox>div {
        background-color: white;
        border-color: #4A90E2;
    }
    </style>
    """, unsafe_allow_html=True)


    with st.sidebar:
        if st.button("New Chat", use_container_width=True):
            reset_session_state(config)

        st.subheader("Danh sách các truy vấn:")
        selected_query = st.selectbox(
            "Chọn truy vấn",
            options=st.session_state.query_list[::-1] + ["--Chọn truy vấn--"],
            key="query_selectbox"
        )

        if selected_query != "--Chọn truy vấn--":
            st.session_state.selected_query = selected_query

        if st.session_state.selected_query:
            selected_results = st.session_state.queries_and_results[st.session_state.selected_query]
            for i, result in enumerate(selected_results):
                with st.expander(f"Top {i+1}: {result['title'][0].upper() + result['title'][1:50]}... (Score: {result['bi_score']:.2f})"):
                    st.markdown(f"**Văn bản:** {result['law_id']}")
                    st.markdown(f"**Tiêu đề:** {result['title'][0].upper() + result['title'][1:]}")
                    st.markdown(f"**Nội dung:** {result['text']}")

        st.subheader("Cài đặt Retrieve")
        new_top_k = st.slider("Chọn số lượng kết quả top-k:", min_value=1, max_value=30, value=st.session_state.top_k)
        if new_top_k != st.session_state.top_k:
            st.session_state.top_k = new_top_k
            update_all_queries_results(config) 

        st.session_state.model = st.selectbox("Chọn mô hình Ollama", models)
        with st.expander("Custom Prompt", expanded=False):
            custom_prompt = st.text_area("Prompt Template", value=st.session_state.custom_prompt, height=300)
            if st.button("Save"):
                st.session_state.custom_prompt = custom_prompt
                st.success("Prompt template updated.")
                st.caption(st.session_state.custom_prompt)
                
    
    # Use the custom prompt if available, otherwise use the default one
    prompt_template = st.session_state.custom_prompt

    if len(st.session_state.messages) != 0:
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
    else:
        with st.chat_message("assistant"):
            intro_text = "Chào bạn! Tôi là trợ lý pháp luật Việt Nam. Tôi có thể giúp bạn trả lời các câu hỏi và tìm kiếm về pháp luật Việt Nam. Nếu có câu hỏi gì xin vui lòng nhắn bên dưới!"
            message_placeholder = st.empty()
            for i in range(len(intro_text) + 1):
                message_placeholder.markdown(intro_text[:i+1] + "▌")
                time.sleep(0.005)
            message_placeholder.markdown(intro_text)
            st.session_state.messages.append({
                "role": "assistant",
                "content": intro_text
            })

    col1, col2, col3 = st.columns(3)
    with col1:
        sg_1 = st.button("Điều kiện áp dụng hợp đồng trong pháp luật Việt Nam?", use_container_width=True)

    with col2:
        sg_2 = st.button("Quy định về bảo vệ quyền lợi người tiêu dùng?", use_container_width=True)

    with col3:
        sg_3 = st.button("Quy trình khiếu nại trong pháp luật Việt Nam?", use_container_width=True)

    
    if sg_1:
        user_input = "Điều kiện áp dụng hợp đồng trong pháp luật Việt Nam?"
        process_input(prompt_template, user_input, config)
    elif sg_2:
        user_input = "Quy định về bảo vệ quyền lợi người tiêu dùng?"
        process_input(prompt_template, user_input, config)
    elif sg_3:
        user_input = "Quy trình khiếu nại trong pháp luật Việt Nam?"
        process_input(prompt_template, user_input, config)

    user_input = st.chat_input("Nhập tin nhắn của bạn")
    if user_input:
        process_input(prompt_template, user_input, config)